# This code is modified from https://github.com/jakesnell/prototypical-networks 

import backbone
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from methods.meta_template import MetaTemplate
from methods.min_norm_solvers import MinNormSolver, gradient_normalizers

class ProtoNet_MOML(MetaTemplate):
    """Prototypical network trained on clean and PGD-perturbed query images.

    Each episode yields two losses: the cross-entropy of the clean query set
    ("acc") and the cross-entropy of a PGD-perturbed copy of the query set
    ("rob").  The two are combined with weights chosen by ``weighting_mode``,
    which must be set by the caller before ``train_loop`` is used:

    * ``'MGDA'`` - weights from ``MinNormSolver.find_min_norm_element`` applied
      to the normalized per-loss gradients,
    * ``'SOML'`` - fixed weights ``[0.8, 0.2]``,
    * ``'ORG'``  - fixed weights ``[1, 0]`` (clean loss only).
    """

    def __init__(self, model_func,  n_way, n_support):
        super(ProtoNet_MOML, self).__init__( model_func,  n_way, n_support)
        self.loss_fn = nn.CrossEntropyLoss()
        # Left unset on purpose; train_loop raises ValueError until the caller
        # picks one of 'MGDA', 'SOML' or 'ORG'.
        self.weighting_mode = None

    def _model_device(self):
        """Return the device holding the model parameters.

        Falls back to CUDA when available (else CPU) if the module has no
        parameters at all.
        """
        first_param = next(self.parameters(), None)
        if first_param is not None:
            return first_param.device
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _query_labels(self, device):
        """Return int64 labels ``[0]*n_query + [1]*n_query + ...`` on ``device``."""
        labels = np.repeat(range(self.n_way), self.n_query)
        # .long() guards against platforms where numpy defaults to int32,
        # which CrossEntropyLoss does not accept as a target dtype.
        return torch.from_numpy(labels).long().to(device)

    def _split_episode(self, x):
        """Update ``n_query``/``n_way`` from ``x`` and split it into support and query.

        ``x`` is indexed as ``x[class, sample, ...]``; the first ``n_support``
        samples of every class form the support set and the rest the query set.

        Returns:
            ``(x_s, x_q)`` flattened to ``(n_way * n_support, ...)`` and
            ``(n_way * n_query, ...)`` respectively.
        """
        self.n_query = x.size(1) - self.n_support
        if self.change_way:
            self.n_way = x.size(0)
        sample_shape = x.size()[2:]
        x_s = x[:, :self.n_support].reshape(self.n_way * self.n_support, *sample_shape)
        x_q = x[:, self.n_support:].reshape(self.n_way * self.n_query, *sample_shape)
        return x_s, x_q

    def clamp(self, X, lower_limit, upper_limit):
        """Element-wise clamp of ``X`` to ``[lower_limit, upper_limit]`` (tensor bounds)."""
        return torch.max(torch.min(X, upper_limit), lower_limit)

    def test_PGD(self, x_q, z_proto, step_num=7):
        """Craft PGD (sign-gradient) adversarial versions of the query images.

        The perturbation budget is 2/255 divided by the per-channel std
        ``[0.229, 0.224, 0.225]``.  The image bounds are the literals below,
        which equal ``(1 - mean) / std`` and ``(0 - mean) / std`` for
        ``mean = [0.485, 0.456, 0.406]``.
        Assumes 3-channel inputs whose channel axis is third from last, so
        that the ``(3, 1, 1)`` constants broadcast - TODO confirm with callers.

        Args:
            x_q: clean query images.
            z_proto: class prototypes; treated as constants by the attack.
            step_num: number of PGD iterations; values below 1 return the
                clean images unchanged.

        Returns:
            The adversarial images as a leaf tensor with ``requires_grad=True``.
        """
        device = self._model_device()
        x_q = x_q.to(device)
        # Only the gradient w.r.t. the images is needed, so cut the prototypes
        # out of the graph; this does not change that gradient.
        z_proto = z_proto.detach()

        eps = 2.0 / 255 * torch.tensor([1.0, 1.0, 1.0], dtype=torch.float32, device=device)
        std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32, device=device)
        epsilon = (eps / std).reshape(3, 1, 1)
        upper_limit = torch.tensor([2.2489, 2.4286, 2.6400],
                                   dtype=torch.float32, device=device).reshape(3, 1, 1)
        lower_limit = torch.tensor([-2.1179, -2.0357, -1.8044],
                                   dtype=torch.float32, device=device).reshape(3, 1, 1)

        # enable_grad keeps the attack usable when the caller runs evaluation
        # inside torch.no_grad().
        with torch.enable_grad():
            images = x_q.detach().requires_grad_(True)
            if step_num < 1:
                return images

            step_size = 1.5 / step_num * epsilon
            for _ in range(step_num):
                _, loss = self.compute_q_loss(images, z_proto)
                grad = torch.autograd.grad(loss, images,
                                           retain_graph=False, create_graph=False)[0]
                adv_images = images.detach() + step_size * torch.sign(grad.detach())
                # Project back into the epsilon-ball around the clean images,
                # then into the valid image range.
                delta = self.clamp(adv_images - x_q, -epsilon, epsilon)
                adv_images = self.clamp(x_q + delta, lower_limit, upper_limit)
                images = adv_images.detach().requires_grad_(True)
        return images

    def compute_proto(self, x_s):
        """Embed the support images and average them per class.

        Returns:
            Tensor of shape ``(n_way, feature_dim)`` holding one prototype per class.
        """
        x_s = x_s.to(self._model_device())
        x_s = x_s.reshape(self.n_way*self.n_support, *x_s.size()[-3:])
        z_support = self.feature.forward(x_s)
        return z_support.reshape(self.n_way, self.n_support, -1).mean(1)

    def compute_q_loss(self, x_q, z_proto):
        """Score the query images against the prototypes.

        Requires ``self.n_query`` to be set (done by the train/test loops).

        Returns:
            ``(scores, loss)`` where ``scores`` is the negative squared
            Euclidean distance of every query embedding to every prototype and
            ``loss`` is the cross-entropy of those scores.
        """
        device = self._model_device()
        x_q = x_q.to(device)
        x_q = x_q.reshape(self.n_way*self.n_query, *x_q.size()[-3:])
        z_query = self.feature.forward(x_q)
        z_query = z_query.reshape(self.n_way* self.n_query, -1)
        scores = -euclidean_dist(z_query, z_proto)
        y_query = self._query_labels(device)
        return scores, self.loss_fn(scores, y_query)

    def _loss_weights(self, q_loss, q_loss_adv):
        """Return the ``[clean, adversarial]`` loss weights for ``weighting_mode``.

        Raises:
            ValueError: if ``weighting_mode`` is not 'MGDA', 'SOML' or 'ORG'.
        """
        mode = self.weighting_mode
        if mode == 'SOML':
            return [0.8, 0.2]
        if mode == 'ORG':
            return [1, 0]
        if mode == 'MGDA':
            params = list(self.parameters())
            # retain_graph=True: the same graphs are reused for the final
            # backward pass on the weighted loss.
            grads = {}
            grads['acc'] = list(torch.autograd.grad(q_loss, params, retain_graph=True))
            grads['rob'] = list(torch.autograd.grad(q_loss_adv, params, retain_graph=True))
            loss_data = {'acc': q_loss.item(), 'rob': q_loss_adv.item()}
            gn = gradient_normalizers(grads, loss_data, normalization_type='loss+')
            for task in ('acc', 'rob'):
                grads[task] = [g / gn[task] for g in grads[task]]
            sol, _ = MinNormSolver.find_min_norm_element([grads[task] for task in ('acc', 'rob')])
            return sol
        raise ValueError(
            "Unsupported weighting_mode {!r}; expected 'MGDA', 'SOML' or 'ORG'".format(mode))

    def train_loop(self, epoch, train_loader, optimizer):
        """Run one training epoch over ``train_loader``.

        For every episode the clean and adversarial query losses are computed,
        weighted according to ``weighting_mode`` and back-propagated.  Running
        averages of both losses are printed every 10 batches.

        Raises:
            ValueError: if ``weighting_mode`` is not a supported mode.
        """
        print_freq = 10

        avg_loss = 0
        avg_loss_adv = 0
        for i, (x, _) in enumerate(train_loader):
            x_s, x_q = self._split_episode(x)
            optimizer.zero_grad()

            z_proto = self.compute_proto(x_s)
            # The adversarial images are plain inputs to the training loss.
            x_q_adv = self.test_PGD(x_q, z_proto).detach()
            _, q_loss = self.compute_q_loss(x_q, z_proto)
            _, q_loss_adv = self.compute_q_loss(x_q_adv, z_proto)

            sol = self._loss_weights(q_loss, q_loss_adv)
            loss = float(sol[0])*q_loss + float(sol[1])*q_loss_adv
            loss.backward()
            optimizer.step()
            avg_loss = avg_loss+q_loss.item()
            avg_loss_adv = avg_loss_adv+q_loss_adv.item()

            if i % print_freq==0:
                print('Epoch {:d} | Batch {:d}/{:d} | Loss {:f} Loss ADV {:f}'.format(epoch, i, len(train_loader), avg_loss/float(i+1), avg_loss_adv/float(i+1)))

    def correct(self, x_q, z_proto):
        """Count top-1 correct predictions on a query set.

        Returns:
            ``(number_correct, number_of_queries)`` with the first as a float.
        """
        with torch.no_grad():
            scores, _ = self.compute_q_loss(x_q, z_proto)
            y_query = np.repeat(range( self.n_way ), self.n_query)

            _, topk_labels = scores.data.topk(1, 1, True, True)
            topk_ind = topk_labels.cpu().numpy()
            top1_correct = np.sum(topk_ind[:,0] == y_query)
            return float(top1_correct), len(y_query)

    @staticmethod
    def _harmonic_mean(a, b):
        """Harmonic mean ``2ab / (a + b)`` element-wise, defined as 0 where ``a + b == 0``."""
        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        total = a + b
        # Without the mask an episode with 0% clean and 0% adversarial accuracy
        # would yield NaN and poison the averaged score.
        return np.divide(2 * (a * b), total, out=np.zeros_like(total), where=total != 0)

    def test_loop(self, test_loader, record=None, return_std=False):
        """Evaluate clean and PGD-adversarial accuracy over ``test_loader``.

        ``record`` and ``return_std`` are accepted but not used by this
        implementation.

        Returns:
            ``(mean clean accuracy, mean adversarial accuracy, mean B score)``
            in percent, where the B score is the per-episode harmonic mean of
            the two accuracies.
        """
        acc_all = []
        acc_all_adv = []

        iter_num = len(test_loader)
        for i, (x, _) in enumerate(test_loader):
            x_s, x_q = self._split_episode(x)
            # Prototypes are only used as constants here, so no graph is kept.
            with torch.no_grad():
                z_proto = self.compute_proto(x_s)
            x_q_adv = self.test_PGD(x_q, z_proto)

            correct_this, count_this = self.correct(x_q, z_proto)
            correct_this_adv, count_this_adv = self.correct(x_q_adv, z_proto)
            acc_all.append(correct_this/ count_this*100)
            acc_all_adv.append(correct_this_adv/ count_this_adv*100)

        acc = np.asarray(acc_all)
        acc_adv = np.asarray(acc_all_adv)
        B_score = self._harmonic_mean(acc, acc_adv)
        B_score2 = float(self._harmonic_mean(np.mean(acc), np.mean(acc_adv)))
        print('%d Test Acc = %4.2f%% +- %4.2f%%' %(iter_num, np.mean(acc),
                                                   1.96* np.std(acc)/np.sqrt(iter_num)))
        print('%d Test ADV Acc = %4.2f%% +- %4.2f%%' %(iter_num, np.mean(acc_adv),
                                                   1.96* np.std(acc_adv)/np.sqrt(iter_num)))
        print('%d Test B Acc = %4.2f%% and %4.2f%% +- %4.2f%%' %(iter_num, B_score2, np.mean(B_score),
                                                   1.96* np.std(B_score)/np.sqrt(iter_num)))
        return np.mean(acc), np.mean(acc_adv), np.mean(B_score)

def euclidean_dist( x, y):
    """Return the pairwise squared Euclidean distances between rows of x and y.

    Args:
        x: tensor of shape (N, D).
        y: tensor of shape (M, D).

    Returns:
        Tensor of shape (N, M) whose entry [i, j] is ``sum((x[i] - y[j]) ** 2)``.
    """
    num_x, num_y = x.size(0), y.size(0)
    feat_dim = x.size(1)
    assert feat_dim == y.size(1)

    # Expand both operands to a common (N, M, D) shape before subtracting.
    pair_shape = (num_x, num_y, feat_dim)
    x_pairs = x.unsqueeze(1).expand(*pair_shape)
    y_pairs = y.unsqueeze(0).expand(*pair_shape)

    diff = x_pairs - y_pairs
    return torch.pow(diff, 2).sum(2)

